#ifndef AHLGREN_SATSOLVER
#define AHLGREN_SATSOLVER

#include "propositional.h"
#include <iostream>
#include <vector>
#include <map>
#include <set>
#include <functional> // greater<double> for VSIDS

namespace prolog {
	using namespace std;

	// Bit string: a vector holding one bit_t (an int) per position
	using bit_t = int;
	using bitstring = vector<bit_t>;
	// Stream output for bit strings (defined out of line)
	ostream& operator<<(ostream& os, const bitstring& b);


	// *** Define Watched Clause class *** //
	// A clause is kept as an ordered set of integer literals (a literal and its
	// complement are l and -l) together with up to two watched literals; a watch
	// slot holding 0 is empty. Only move operations are provided, so clauses are
	// transferred with move() rather than copied.
	class WClause {
	public:
		typedef set<int>::iterator iterator;
		typedef set<int>::const_iterator const_iterator;
		// Empty clause, no watches
		WClause() : watch(0,0) {}
		// Single-literal clause {l}, watching l
		WClause(int l) : watch(l,0) { lits.insert(l); }
		// Clause {x, y} watching both; if x == y there is only one literal, so it is watched once
		WClause(int x, int y) : watch(x, x == y ? 0 : y) { lits.insert(x); lits.insert(y); }
		// Construct from a bit string (defined out of line)
		WClause(const bitstring& bs);

		// Clause from the literals in [beg, end). The first two *distinct* literals in
		// iteration order are watched, so a repeated leading literal is not watched twice.
		// The range is traversed exactly once, so single-pass input iterators also work.
		template <typename Iter> WClause(Iter beg, Iter end) : watch(0,0) {
			for (; beg != end; ++beg) {
				const int l = *beg;
				lits.insert(l);
				// Same policy as insert(): fill a free slot, never watch a literal twice
				if (watch.second == 0 && watch.first != l) {
					if (watch.first == 0) watch.first = l;
					else watch.second = l;
				}
			}
		}

		// Move-only: declaring the move operations already suppressed copying; made explicit here
		WClause(const WClause&) = delete;
		WClause& operator=(const WClause&) = delete;
		WClause(WClause&& c) : lits(move(c.lits)), watch(move(c.watch)) { }
		WClause& operator=(WClause&& c) { if (this != &c) { lits = move(c.lits); watch = move(c.watch); } return *this; }
		// Add a literal, filling a free watch slot if there is one (defined below)
		void insert(int lit);
		// Erase literal l; returns true if it was present.
		// Watch slots are not updated, so l may still be reported as watched.
		bool erase(int l) { return lits.erase(l) != 0; }
		// Remove all literals and empty both watch slots
		void clear() { lits.clear(); watch.first = watch.second = 0; }
		// The sole literal if the clause has exactly one literal, otherwise 0
		int unit() const { return (lits.size() == 1) ? *lits.begin() : 0; }
		// True for the empty clause (no literals)
		bool empty() const { return lits.empty(); }
		// Number of literals
		unsigned size() const { return lits.size(); }

		// Does literal l occur (with exactly this sign)?
		bool has(int l) const { return lits.find(l) != lits.end(); }
		// 1 if l occurs, otherwise -1 if -l occurs, otherwise 0
		int sign(int l) const;
		// Equality compares the literal sets only; watch slots are ignored
		bool operator==(const WClause& c) const { return lits == c.lits; }
		bool operator!=(const WClause& c) const { return !(lits == c.lits); }
		// Iterate over the literals in ascending order
		iterator begin() { return lits.begin(); }
		iterator end() { return lits.end(); }
		const_iterator begin() const { return lits.begin(); }
		const_iterator end() const { return lits.end(); }
		// True when both watch slots are occupied
		bool watching() const { return watch.first != 0 && watch.second != 0; }
		// Watched literals (non-const overloads return assignable references)
		pair<int,int>& watched() { return watch; }
		pair<int,int> watched() const { return watch; }
		int& watched1() { return watch.first; }
		int watched1() const { return watch.first; }
		int& watched2() { return watch.second; }
		int watched2() const { return watch.second; }
		// The watched literal paired with l, or 0 if l is not watched
		int other_watched(int l) const;
		// Replace watched literal lo by ln; returns false if lo is not watched
		bool switch_watched(int lo, int ln);

		// Satisfaction check against the assignment va (defined out of line)
		bool satisfied(const vector<pair<int,bool> >& va) const;

		// Print the clause (defined out of line)
		void print(ostream& os) const;
	protected:
		set<int> lits;       // literals of the clause
		pair<int,int> watch; // watched literals; 0 = empty slot
	};

	// Stream a clause through its print() member
	inline ostream& operator<<(ostream& os, const WClause& c)
	{
		c.print(os);
		return os;
	}
	// Implication operators: "assignment >> clause" and "clause << assignment"
	// both ask WClause::satisfied about the assignment
	inline bool operator>>(const vector<pair<int,bool> >& v, const WClause& c)
	{
		return c.satisfied(v);
	}
	inline bool operator<<(const WClause& c, const vector<pair<int,bool> >& v)
	{
		return v >> c; // same question, other spelling
	}


	//// *** STL set does not allow changing elements, so implement counter *** //
	//struct VCNode {
	//	VCNode(int i, double d = 0.0) : l(i), c(d), nl(nullptr), nm(nullptr) {}
	//	unsigned size() const {
	//		unsigned s = 1;
	//		if (less) s += less->size();
	//		if (more) s += more->size();
	//		return s;
	//	}
	//	void half() {
	//		c /= 2.0;
	//		if (less) less->half();
	//		if (more) more->half();
	//	}
	//	int l; // literal
	//	double c; // counter
	//	VCNode* less; // subtree with values less than this node
	//	VCNode* more; // subtree with values more than this node
	//};
	//class VCounter {
	//public:
	//	VCounter() : root(nullptr) {}
	//	unsigned size() const { return root ? root->size() : 0; }
	//	bool empty() const { return root == nullptr; }
	//	VCNode* find(int l) {
	//		VCNode* ptr = root;
	//		for (;;) {
	//			if (l < ptr->l) {
	//				if (ptr->less) ptr = ptr->less;
	//				else return nullptr; // not found
	//			} else if (l > ptr->l) {
	//				if (ptr->more) ptr = ptr->more;
	//				else return nullptr; // not found
	//			} else break; // found it
	//		}
	//		return ptr;
	//	}
	//	bool insert(int l) {
	//		VCNode* ptr = root;
	//		for (;;) {
	//			if (l < ptr->l) {
	//				ptr->less ? ptr = ptr->less : ptr->less = new VCNode(l);
	//			} else if (l > ptr->l) {
	//				ptr->more ? ptr = ptr->more : ptr->more = new VCNode(l);
	//			} else return false; // already present
	//		}
	//	}
	//	bool increase(int l) {
	//		if (!root) return false;
	//		VCNode* parent = nullptr;
	//		VCNode* ptr = root;
	//		for (;;) {
	//			if (l < ptr->l) {
	//				if (ptr->less) { parent = ptr; ptr = ptr->less; }
	//				else return false; // not found
	//			} else if (l > ptr->l) {
	//				if (ptr->more) { parent = ptr; ptr = ptr->more; }
	//				else return false; // not found
	//			} else {
	//				break; // found it
	//			}
	//		}
	//		// Detach it
	//		if (parent->less == ptr) {
	//			parent->less = nullptr;
	//		} else {
	//			assert(parent->more == ptr);
	//			parent->more == nullptr;
	//		}
	//		// Increase it
	//		++(ptr->c);
	//		// Move this node to appropriate location
	//	}
	//protected:
	//	VCNode* root;
	//};


	// Function object to keep Clause DB counter sorted
	//struct counter_cmp {
	//	// Bigger counter first, if equal, then lowest variable
	//	bool operator()(const pair<int,double>& x, const pair<int,double>& y) const
	//	{
	//		if (x.second > y.second) return true;
	//		else if (x.second < y.second) return false;
	//		else if (x.first == -y.first) return x.first < 0;
	//		else return abs(x.first) < abs(y.first);
	//	}
	//};

	// *** Define Clause Database class *** //
	class ClauseDB {
	public:
		// Define clause database
		typedef vector<WClause> db_type;
		// Define clause iterator
		typedef db_type::iterator iterator;
		typedef db_type::const_iterator const_iterator;
		// Define Watch List Type
		//typedef multimap<int,iterator> wlist_type;
		// Define watch list iterator
		//typedef wlist_type::iterator wliter;

		ClauseDB() { db.reserve(128); } // reserve space for at least 128 clauses
		// Add WClause to DB if consistent, returns false otherwise
		//bool insert(const WClause&);
		bool insert(WClause&&);
		//bool insert(const bitstring& bs);
		// Erase clause
		iterator erase(iterator i) { return db.erase(i); }
		// Simplify DB by turning unit clauses into assignments and propagating
		bool simplify();
		// Use DPLL to retrieve model, or return false is unsatisfiable
		// DPLL does not propagate pre-assignments, see simplify() above
		bool dpll(vector<pair<int,bool> >& model, unsigned model_size);
		// Use DPLL to retrieve bitstring
		bool dpll(bitstring& b, unsigned s);
		// Use DPLL for pure SAT checking
		bool dpll(unsigned s) { vector<pair<int,bool> > v; return dpll(v,s); }

		// Logical
		bool satisfied(const vector<pair<int,bool> >& va) const;

		// Print database
		void print(ostream&) const;
	protected:
		vector<WClause> db;
		//wlist_type wlist;
		vector<pair<int,bool> > va;

	};

	// Stream a clause database through its print() member
	inline ostream& operator<<(ostream& os, const ClauseDB& cdb)
	{
		cdb.print(os);
		return os;
	}
	// Implication operators: both spellings ask ClauseDB::satisfied about the assignment
	inline bool operator>>(const vector<pair<int,bool> >& v, const ClauseDB& db)
	{
		return db.satisfied(v);
	}
	inline bool operator<<(const ClauseDB& db, const vector<pair<int,bool> >& v)
	{
		return v >> db; // same question, other spelling
	}


	// ************ Define Helper Functions ****************** //

	// Build a clause from bottom clause `bot` and selection bits `bs`: the head is
	// copied, and body literal k is copied when bs[k] == 1. With include_masks, the
	// unselected literals are kept as well, with "_#" prepended to their symbol.
	// NOTE(review): the result comes from head()->copy(), presumably caller-owned -- confirm.
	// NOTE(review): assumes bs has no more entries than bot has body literals.
	template <typename T>
	Functor<T>* make_functor_ptr(const Functor<T>* bot, const bitstring& bs, bool include_masks = false)
	{
		Functor<T>* result = bot->head()->copy();
		for (unsigned k = 0; k < bs.size(); ++k) {
			const bool selected = (bs[k] == 1);
			if (!selected && !include_masks) continue; // literal dropped entirely
			Functor<T>* lit = bot->body(k)->copy();
			if (!selected) lit->symbol().insert(0,"_#"); // keep as a masked placeholder
			result->body_push_back(lit);
		}
		return result;
	}

	// Value-returning counterpart of make_functor_ptr: copy the head of `bot`, then
	// push a copy of body literal k when bs[k] == 1; with include_masks, unselected
	// literals are pushed too, with "_#" prepended to their symbol.
	// Both branches now hand f a pointer obtained from copy(). Previously the masked
	// branch copied *copy() into a stack local (discarding the returned pointer) and
	// pushed the local's address, which dangled as soon as the branch ended.
	// NOTE(review): assumes bs has no more entries than bot has body literals.
	template <typename T>
	Functor<T> make_functor(const Functor<T>& bot, const bitstring& bs, bool include_masks = false)
	{
		Functor<T> f = *bot.head();
		for (unsigned k = 0; k < bs.size(); ++k) {
			if (bs[k] == 1) {
				f.body_push_back( bot.body(k)->copy() );
			} else if (include_masks) {
				Functor<T>* l = bot.body(k)->copy();
				l->symbol().insert(0,"_#");
				f.body_push_back(l);
			}
		}
		return f;
	}

	// Convert a model to a bit string (defined out of line).
	// NOTE(review): presumably one entry per variable of the model -- confirm against the definition.
	bitstring make_bitstring(const Model& m);

	// Clause of candidates that remain after pruning the space around bit string bs.
	// NOTE(review): the sign of `consistent` presumably selects the same three modes as
	// the Functor<T> overload below (< 0 prune up, > 0 prune down, 0 prune bs only) --
	// confirm against the out-of-line definition.
	WClause make_complement_clause(const bitstring& bs, int consistent);

	// Clause of candidates that remain after pruning the search space around f.
	// Variable k+1 stands for body literal k; a body literal whose symbol starts
	// with "_#" is masked (unselected), as make_functor(_ptr) produce with include_masks.
	//   consistent <  0 : prune up   -- [0,1,1,0] => 0**0, allowed: b1 \/ b4
	//   consistent >  0 : prune down -- [0,1,1,0] => *11*, allowed: -b2 \/ -b3
	//   consistent == 0 : prune one  -- [0,1,1,0] => 0110, allowed: b1 \/ -b2 \/ -b3 \/ b4
	// Literals are inserted in body order, so the first two become the watched pair.
	template <typename T>
	WClause make_complement_clause(const Functor<T>& f, int consistent)
	{
		WClause clause;
		const int n = static_cast<int>(f.body_size());
		for (int i = 0; i < n; ++i) {
			const bool masked = (f.body(i)->symbol().compare(0,2,"_#") == 0);
			const int var = i + 1;
			if (masked) {
				// Masked literals must be allowed back in (prune up / prune one)
				if (consistent <= 0) clause.insert(var);
			} else {
				// Selected literals must be allowed out (prune down / prune one)
				if (consistent >= 0) clause.insert(-var);
			}
		}
		return clause;
	}

	// Remaining subspace after pruning functor space
	template <typename T, typename Iter>
	void make_mode_clause(
		const Functor<T>& f, 
		const map<Functor<T>,Mode<T> >& fm,
		ClauseDB& db)
	{
		// h <- a,b,c
		// X->Y iff X has output var that is in Y
		// If X is active, then it depends on having its input variables instantiated
		// not(X) OR ((A or B) and (C or D)) = -X \/ (C1 /\ C2 /\ ...)
		// = (-X \/ C1) /\ (-X \/ C2) /\ ...

		// Store outputs of each body literal
		vector<set<string> > out_var_vec;
		for_each(f.body_begin(),f.body_end(),[&](const Functor<T>* l){
			const Mode<T>& omode = fm.find(*l)->second;
			set<string> ovars;
			omode.variables(Mode<T>::output,l,ovars);
			out_var_vec.push_back(move(ovars));
		});

		// Grab all input variables of X
		const int f_body_size = static_cast<int>(f.body_size());
		for (int k = 1; k < f_body_size; ++k) { // skip first literal since it necessarily takes input from head
			const Functor<T>* l = f.body(k);
			//cerr << "Input variables of " << *l << ":\n";
			const Mode<T>& imode = fm.find(*l)->second;
			set<string> ivars;
			imode.variables(Mode<T>::input,l,ivars);
			for (auto v = ivars.begin(); v != ivars.end(); ++v) {
				//cerr << "Input variable: " << *v << " <- ";
				// For each input variable, find all possible output literals
				WClause lout; // literals with corresponding output variable
				for (int n = 0; n < k; ++n) {
					const set<string>& ovars = out_var_vec[n];
					if (ovars.find(*v) != ovars.end()) {
						//cerr << n << " ";
						lout.insert(n+1);
					}
				}
				//cerr << "\n";
				// If no bindings were made, this literal takes input directly from head only
				if (!lout.empty()) {
					lout.insert(-k-1); // either it is not present, or we have all input literals
					db.insert(move(lout));
				}
			}
		}
	}


	// *** Define inlined methods *** //

	// Add literal lit. While the second watch slot is empty, lit takes the first
	// free slot -- unless lit already sits in the first slot. Previously a repeated
	// insert filled both slots with the same literal, making watching() report true
	// for a single-literal clause.
	inline void WClause::insert(int lit) 
	{ 
		lits.insert(lit);
		if (watch.second != 0 || watch.first == lit) return; // slots full, or already watched
		if (watch.first == 0) watch.first = lit;
		else watch.second = lit;
	}

	// 1 if literal l occurs; otherwise -1 if its complement -l occurs; otherwise 0
	inline int WClause::sign(int l) const 
	{
		const bool positive = lits.count(l) != 0;
		if (positive) return 1;
		return (lits.count(-l) != 0) ? -1 : 0;
	}

	// Partner of watched literal l, or 0 if l is in neither watch slot.
	// The first slot is checked first.
	inline int WClause::other_watched(int l) const 
	{
		if (l == watch.first) return watch.second;
		return (l == watch.second) ? watch.first : 0;
	}

	// Replace watched literal lo by ln (first slot checked first); false if lo is not watched
	inline bool WClause::switch_watched(int lo, int ln)
	{
		int* slot = (watch.first == lo) ? &watch.first
		          : (watch.second == lo) ? &watch.second
		          : nullptr;
		if (slot == nullptr) return false;
		*slot = ln;
		return true;
	}



}

#endif

